from __future__ import annotations

from dataclasses import dataclass
import os
from typing import Callable, Protocol, Sequence, Tuple, Optional, Dict, List
import random
import concurrent.futures


class LLMClient(Protocol):
    """Structural type of the chat client the engine talks to.

    Any object exposing a compatible ``chat`` method satisfies this protocol;
    no inheritance is required.
    """

    def chat(
        self,
        messages: list[dict],
        temperature: float,
        max_tokens: int,
        reasoning_max_tokens: Optional[int] = None,
        reasoning_exclude: Optional[bool] = None,
    ) -> tuple[str, int]:
        """Send ``messages`` and return a ``(text, tokens)`` pair.

        Callers in this file always pass a single
        ``{"role": "user", "content": <prompt>}`` message and invoke the
        method with keyword arguments only.

        Args:
            messages: Chat messages to send.
            temperature: Sampling temperature for this request.
            max_tokens: Completion token budget for this request.
            reasoning_max_tokens: Optional reasoning budget; the engine passes
                ``0`` for style generation and summarization and omits it
                elsewhere. Exact semantics are up to the implementation.
            reasoning_exclude: Optional flag; the engine passes ``True``
                alongside ``reasoning_max_tokens=0``. Exact semantics are up
                to the implementation.

        Returns:
            The completion text and a value the engine records as the token
            count of the call.
        """
        ...


@dataclass
class CoThinkerConfig:
    """Settings read by ``CoThinkerEngine`` for one discussion run."""

    # Number of agents that answer in every round.
    num_agents: int = 6
    # Number of discussion rounds run before the final synthesis.
    num_rounds: int = 3
    # Maximum number of peer messages handed to an agent as references.
    num_references: int = 3
    # Probability of rewiring each ring edge; the engine clamps it to [0, 1].
    rewiring_prob: float = 0.3
    # When False, every agent gets an empty style and no style call is made.
    enable_style_generator: bool = True
    # "none" keeps the previous summary, "individual" summarizes each message
    # separately; any other value takes the TMS summarization path.
    summarizer: str = "tms"  # one of {"tms", "individual", "none"}
    # max_tokens passed to the style, agent-turn, summarization and synthesis
    # chat calls respectively.
    max_tokens_style: int = 8192
    max_tokens_turn: int = 8192
    max_tokens_summarize: int = 8192
    max_tokens_synth: int = 8192
    # Temperature for style generation and for agent turns in the first round.
    init_temperature: float = 0.25
    # Temperature for later-round agent turns, summarization and synthesis.
    followup_temperature: float = 0.25
    # When True (and num_agents > 1) references come from a rewired ring
    # topology; otherwise every other agent is a reference candidate.
    small_world: bool = True
    # The engine clamps this to the range [1, (num_agents - 1) // 2].
    ring_k: int = 2  # number of neighbors on each side in initial ring (total degree ~2*ring_k)
    # The engine only checks for the exact values "neighbor"/"global" and for
    # the substring "embed"; every other value (including the default "none")
    # falls through to taking the first candidates in index order.
    reference_mode: str = "none"  # none | neighbor-embed | neighbor | global-embed | global
    # Forwarded to EmbeddingConfig(model_name=..., device=...) when an
    # embedding reference mode is active.
    embed_model: str = "all-MiniLM-L6-v2"
    embed_device: Optional[str] = None


class CoThinkerEngine:
    def __init__(
        self,
        client_factory: Callable[[], LLMClient],
        prompt_registry: "PromptRegistry",
        config: CoThinkerConfig,
    ) -> None:
        self.client_factory = client_factory
        self.prompts = prompt_registry
        self.cfg = config
        self._embedder = None

    def _dbg(self, msg: str) -> None:
        flag = os.environ.get("COTHINKER_DEBUG", "").lower()
        if flag in {"1", "true", "yes", "on"}:
            print(f"[CoThinker] {msg}", flush=True)

    def _get_embedder(self):
        if self._embedder is None and ("embed" in self.cfg.reference_mode):
            try:
                from .components.embedding import EmbeddingConfig, SentenceTransformerEmbedder
                self._embedder = SentenceTransformerEmbedder(
                    EmbeddingConfig(model_name=self.cfg.embed_model, device=self.cfg.embed_device)
                )
            except Exception:
                self._embedder = None
        return self._embedder

    def _gen_styles(self, task: str, client: LLMClient) -> tuple[list[str], int, str]:
        if not self.cfg.enable_style_generator:
            return ([""] * self.cfg.num_agents), 0, ""
        prompt = self.prompts.style_orchestrator(task)
        out, tok = client.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=self.cfg.init_temperature,
            max_tokens=self.cfg.max_tokens_style,
            reasoning_max_tokens=0,
            reasoning_exclude=True,
        )
        styles = []
        for line in out.splitlines():
            line = line.strip()
            if not line:
                continue
            if line[0].isdigit():
                # strip leading numbering like "1. ..."
                split_idx = line.find(" ")
                styles.append(line[split_idx + 1 :] if split_idx > 0 else line)
            else:
                styles.append(line)
        if len(styles) < self.cfg.num_agents:
            styles += [""] * (self.cfg.num_agents - len(styles))
        return styles[: self.cfg.num_agents], (int(tok or 0) if tok is not None else 0), prompt

    def _summarize(self, task: str, prev_summary: str, round_msgs: list[str], client: LLMClient) -> tuple[str, int, str]:
        if self.cfg.summarizer == "none":
            return prev_summary, 0, ""
        if self.cfg.summarizer == "individual":
            # naive per-agent summarization then merge
            pieces: list[str] = []
            total_tokens = 0
            for msg in round_msgs:
                if not msg:
                    continue
                prompt = self.prompts.individual_summarizer(task, msg)
                out, tok = client.chat(
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.cfg.followup_temperature,
                    max_tokens=self.cfg.max_tokens_summarize,
                    reasoning_max_tokens=0,
                    reasoning_exclude=True,
                )
                pieces.append(out.strip())
                try:
                    total_tokens += int(tok or 0)
                except Exception:
                    pass
            merged = (prev_summary + "\n" + "\n".join(pieces)).strip()
            return merged, total_tokens, "individual"
        # default: TMS style summarization
        prompt = self.prompts.tms(task, prev_summary, round_msgs)
        out, tok = client.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=self.cfg.followup_temperature,
            max_tokens=self.cfg.max_tokens_summarize,
            reasoning_max_tokens=0,
            reasoning_exclude=True,
        )
        return (out.strip() if out else prev_summary), (int(tok or 0) if tok is not None else 0), prompt

    def _synthesize(self, task: str, summary: str, last_round_msgs: list[str], client: LLMClient) -> tuple[str, int, str]:
        prompt = self.prompts.synthesizer(task, summary, last_round_msgs)
        out, tok = client.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=self.cfg.followup_temperature,
            max_tokens=self.cfg.max_tokens_synth,
        )
        return out.strip(), (int(tok or 0) if tok is not None else 0), prompt

    def run(self, question: dict) -> dict:
        task = question["turns"][-1]
        llm = self.client_factory()
        styles, style_tokens, style_prompt = self._gen_styles(task, llm)
        self._dbg(f"Start run: agents={self.cfg.num_agents}, rounds={self.cfg.num_rounds}, ref_mode={self.cfg.reference_mode}")
        self._dbg(f"Styles generated: tokens={style_tokens}, count={len(styles)}")

        summaries = ["Initial state. No discussion summary yet."]
        all_round_msgs: list[list[str]] = []
        round_traces: list[dict] = []

        # Build small-world topology once
        neighbors: list[list[int]] = [[] for _ in range(self.cfg.num_agents)]
        if self.cfg.small_world and self.cfg.num_agents > 1:
            # Start with ring lattice
            N = self.cfg.num_agents
            k = max(1, min(self.cfg.ring_k, (N - 1) // 2))
            for i in range(N):
                adj = set()
                for d in range(1, k + 1):
                    adj.add((i - d) % N)
                    adj.add((i + d) % N)
                neighbors[i] = sorted(list(adj))
            # Rewire edges with probability p (Watts–Strogatz style, undirected simplified)
            p = max(0.0, min(1.0, self.cfg.rewiring_prob))
            for i in range(N):
                for j in list(neighbors[i]):
                    if j < i:
                        continue  # handle each undirected edge once
                    if random.random() < p:
                        # remove current edge
                        neighbors[i].remove(j)
                        neighbors[j].remove(i)
                        # add a new edge to a random node not currently connected and not self
                        candidates = [x for x in range(N) if x != i and x not in neighbors[i]]
                        if candidates:
                            new_j = random.choice(candidates)
                            neighbors[i].append(new_j)
                            neighbors[new_j].append(i)
                neighbors[i] = sorted(list(set(neighbors[i])))

        # Round loop
        for t in range(self.cfg.num_rounds):
            self._dbg(f"Round {t+1}/{self.cfg.num_rounds} starting")
            msgs_this_round: list[str] = [""] * self.cfg.num_agents
            agents_trace: list[dict] = [{} for _ in range(self.cfg.num_agents)]

            def build_refs(i: int) -> list[str]:
                refs_local: list[str] = []
                if not all_round_msgs:
                    return refs_local
                last_msgs = all_round_msgs[-1]
                if self.cfg.small_world and neighbors[i]:
                    cand_idx = [j for j in neighbors[i] if j != i]
                else:
                    cand_idx = [j for j in range(self.cfg.num_agents) if j != i]
                cands = [(j, last_msgs[j]) for j in cand_idx if last_msgs[j]]
                mode = self.cfg.reference_mode
                if not cands:
                    return []
                if mode in ("neighbor", "global"):
                    return [m for _, m in cands][: self.cfg.num_references]
                embedder = self._get_embedder()
                if embedder is None:
                    return [m for _, m in cands][: self.cfg.num_references]
                prev = all_round_msgs[-1][i] if all_round_msgs else ""
                target_text = prev or (summaries[-1] if summaries else "")
                if not target_text:
                    return [m for _, m in cands][: self.cfg.num_references]
                try:
                    vec_target = embedder.encode([target_text])[0]
                    vec_cands = embedder.encode([m for _, m in cands])
                    import numpy as np
                    sims = (vec_cands @ vec_target)
                    order = np.argsort(-sims)
                    return [cands[int(k)][1] for k in order[: self.cfg.num_references]]
                except Exception:
                    self._dbg("Embedding selection failed; falling back to non-embed top-k")
                    return [m for _, m in cands][: self.cfg.num_references]

            def run_agent(i: int) -> Tuple[int, str, int, str, list[str], float]:
                style = styles[i] if i < len(styles) else ""
                prev = all_round_msgs[-1][i] if all_round_msgs else ""
                refs = build_refs(i)
                prompt = self.prompts.agent_turn(
                    agent_idx=i,
                    style=style,
                    task=task,
                    prev_summary=summaries[-1],
                    prev_answer=prev,
                    references=refs,
                    round_index=t,
                )
                temp = self.cfg.init_temperature if t == 0 else self.cfg.followup_temperature
                out, tok = llm.chat(
                    messages=[{"role": "user", "content": prompt}],
                    temperature=temp,
                    max_tokens=self.cfg.max_tokens_turn,
                )
                return i, out.strip(), (int(tok or 0) if tok is not None else -1), prompt, refs, temp

            with concurrent.futures.ThreadPoolExecutor(max_workers=self.cfg.num_agents) as ex:
                futures = {ex.submit(run_agent, i): i for i in range(self.cfg.num_agents)}
                for fut in concurrent.futures.as_completed(futures):
                    i, out_s, tok_i, prompt_i, refs_i, temp_i = fut.result()
                    msgs_this_round[i] = out_s
                    agents_trace[i] = {
                        "agent_idx": i,
                        "temperature": temp_i,
                        "prompt": prompt_i,
                        "prev_answer": (all_round_msgs[-1][i] if all_round_msgs else ""),
                        "references": refs_i,
                        "output": out_s,
                        "tokens": tok_i,
                    }
                    # Print a small preview for progress visibility
                    try:
                        prev_len = int(os.environ.get("COTHINKER_PREVIEW_CHARS", "160"))
                    except Exception:
                        prev_len = 160
                    preview = (out_s[:prev_len] + ("…" if len(out_s) > prev_len else "")) if out_s else "(empty)"
                    self._dbg(f"Agent {i+1} done (round {t+1}): {preview}")

            all_round_msgs.append(msgs_this_round)
            self._dbg(f"Round {t+1}: collected {len(msgs_this_round)} messages")
            sum_text, sum_tokens, sum_prompt = self._summarize(task, summaries[-1], msgs_this_round, llm)
            summaries.append(sum_text)
            self._dbg(f"Round {t+1}: summary mode={self.cfg.summarizer}, tokens={sum_tokens}")
            round_traces.append({
                "round_index": t,
                "agents": agents_trace,
                "summary": {
                    "mode": self.cfg.summarizer,
                    "prompt": sum_prompt,
                    "output": sum_text,
                    "tokens": sum_tokens,
                },
            })

        self._dbg("Synthesis starting")
        final, synth_tokens, synth_prompt = self._synthesize(task, summaries[-1], all_round_msgs[-1], llm)
        self._dbg(f"Synthesis done: tokens={synth_tokens}")

        return {
            "final_answer": final,
            "round_messages": all_round_msgs,
            "summaries": summaries,
            "styles": styles,
            "config": self.cfg.__dict__,
            "trace": {
                "question": task,
                "neighbors": neighbors,
                "style": {
                    "prompt": style_prompt,
                    "tokens": style_tokens,
                },
                "rounds": round_traces,
                "synthesis": {
                    "prompt": synth_prompt,
                    "output": final,
                    "tokens": synth_tokens,
                },
            },
        }


class PromptRegistry:
    """Renders the engine's prompts from named ``str.format`` templates.

    ``templates`` maps the keys ``style_orchestrator``, ``tms``,
    ``individual_summarizer``, ``synthesizer`` and ``agent_turn`` to template
    strings. A key is only looked up when its method is called, so a missing
    template raises ``KeyError`` at that point.
    """

    def __init__(self, templates: dict[str, str]):
        self.templates = templates

    def style_orchestrator(self, task: str) -> str:
        """Render the style-generation prompt (field: ``task``)."""
        render = self.templates["style_orchestrator"].format
        return render(task=task)

    def tms(self, task: str, prev_summary: str, round_msgs: list[str]) -> str:
        """Render the TMS summary prompt.

        Fields: ``task``, ``prev_summary`` and ``messages`` (the round's
        messages joined by a blank line).
        """
        render = self.templates["tms"].format
        joined = "\n\n".join(round_msgs)
        return render(task=task, prev_summary=prev_summary, messages=joined)

    def individual_summarizer(self, task: str, msg: str) -> str:
        """Render the single-message summary prompt (fields: ``task``, ``message``)."""
        render = self.templates["individual_summarizer"].format
        return render(task=task, message=msg)

    def synthesizer(self, task: str, summary: str, last_msgs: list[str]) -> str:
        """Render the final-synthesis prompt.

        Fields: ``task``, ``summary`` and ``last_messages`` (the messages
        joined by a blank line).
        """
        render = self.templates["synthesizer"].format
        joined = "\n\n".join(last_msgs)
        return render(task=task, summary=summary, last_messages=joined)

    def agent_turn(self, agent_idx: int, style: str, task: str, prev_summary: str, prev_answer: str, references: list[str], round_index: int) -> str:
        """Render one agent's prompt for one round.

        ``agent_idx`` is shown 1-based. An empty ``prev_answer`` and an empty
        ``references`` list are replaced by fixed placeholder texts;
        references are otherwise joined by a blank line.
        """
        render = self.templates["agent_turn"].format
        display_idx = agent_idx + 1
        answer_text = prev_answer or "(no previous answer)"
        if references:
            refs_text = "\n\n".join(references)
        else:
            refs_text = "(no references this round)"
        fields = {
            "agent_idx": display_idx,
            "style": style,
            "task": task,
            "prev_summary": prev_summary,
            "prev_answer": answer_text,
            "references": refs_text,
            "round_index": round_index,
        }
        return render(**fields)
